from Network.network_utils import run_optimizer
import numpy as np
from ActualCausal.Train.regularizers import apply_regularizers

def generate_random_mask(args, params, model, form, batch):
    """Draw a random binary mask with one row per batch element and one column per object.

    A uniform sample in [0, 1) is drawn for every (element, object) pair.
    Entries whose sample is <= ``params.random_mask_rate`` become 1 and
    entries above it become 0.

    When ``args.active.random_masks.safe_random_type`` is non-empty the
    samples are first multiplied by the output of ``compute_safe``. An entry
    whose safety value is 0 therefore has sample 0 and is set to 1 for any
    non-negative rate.

    ``form`` is accepted but not used here.
    """
    draws = np.random.rand(len(batch), args.factor.num_objects)
    use_safety = len(args.active.random_masks.safe_random_type) > 0
    if use_safety:
        draws = draws * compute_safe(args, params, model, batch)

    rate = params.random_mask_rate
    # threshold against the untouched draws so the two passes cannot interact
    mask = draws.copy()
    np.putmask(mask, draws > rate, 0)
    np.putmask(mask, draws <= rate, 1)
    return mask

def compute_safe(args, params, model, batch):
    """Return a per-object safety mask for the batch.

    1 means the object is safe to mask out, 0 means it is not.

    When ``args.active.random_masks.safe_random_type`` is "nonproximal" the
    mask is read from ``batch.proximity[:, model.target_name]``. For any other
    value every object is treated as safe and an all-ones array of shape
    ``(len(batch), args.factor.num_objects)`` is returned.

    ``params`` is accepted but not used here.
    """
    # TODO: add more measures of safety
    if args.active.random_masks.safe_random_type == "nonproximal":
        # only works with target networks, not all masks
        # NOTE(review): this uses proximity itself as the "safe" flag, although
        # the option is named "nonproximal" -- confirm the intended polarity
        return batch.proximity[:, model.target_name]
    # the shape must be one tuple: a second positional argument to np.ones is the dtype
    return np.ones((len(batch), args.factor.num_objects))
     

def generate_function_mask(args, params, model, form, batch):
    """Draw a random binary mask that is always filtered by the safety criteria.

    Uniform samples in [0, 1) are drawn for every (batch element, object)
    pair and multiplied by the output of ``compute_safe``. Entries whose
    resulting value is <= ``params.random_mask_rate`` become 1 and entries
    above it become 0.

    ``form`` is accepted but not used here.
    """
    scores = np.random.rand(len(batch), args.factor.num_objects) * compute_safe(args, params, model, batch)
    threshold = params.random_mask_rate

    # compute both selections from the unmodified scores before writing anything
    above = scores > threshold
    at_or_below = scores <= threshold

    mask = scores.copy()
    mask[above] = 0
    mask[at_or_below] = 1
    return mask


def train_random_active(args, params, model, buffer, form="all", log_batch=None, wrap_function=None, additional=None, itr_num=0, intermediate_logger = None):
    """Train the forward model while randomly masking out part of the input.

    This allows methods to scale even when the data does not inherently have
    batch.valid style masking. Valid masking is still respected: the random
    mask is multiplied with ``batch.valid`` before inference.

    Runs ``args.active.active_steps`` optimization steps. Each step samples a
    batch from ``buffer``, optionally passes it through ``wrap_function``,
    draws a random mask, runs ``model.infer``, applies the regularizers to the
    negative log-probabilities and calls ``run_optimizer``.

    Args:
        args: configuration; reads ``active.active_steps``,
            ``active.include_gradient`` and ``train.batch_size``.
        params: run-time parameters; ``sample_active_full_weights`` is passed
            to ``buffer.sample``.
        model: provides ``infer`` and ``get_model_optim``.
        buffer: provides ``sample(batch_size, weights)`` returning
            ``(batch, idxes)``.
        form: "all" selects the "all_rand_mask" form, anything else "rand_mask".
        log_batch: forwarded to ``model.infer``; defaults to a fresh empty list.
        wrap_function: optional callable applied to each sampled batch.
        additional: forwarded to ``model.infer``; defaults to a fresh empty list.
        itr_num: outer iteration index, used to offset the logging step.
        intermediate_logger: optional logger called after every step.

    Returns:
        The inference result of the last step, with ``reg_loss`` and
        ``gradients`` attached, or ``None`` when ``active_steps`` is 0.
    """
    # fresh lists per call: list defaults in the signature would be shared across calls
    log_batch = [] if log_batch is None else log_batch
    additional = [] if additional is None else additional

    mask_form = "all_rand_mask" if form == "all" else "rand_mask" # TODO: random masks doesn't use probabilistic masks, just full
    result = None  # stays None if no step runs, instead of raising UnboundLocalError
    for i in range(args.active.active_steps):
        batch, idxes = buffer.sample(args.train.batch_size, params.sample_active_full_weights)
        if wrap_function is not None:
            batch = wrap_function(batch)

        mask = generate_random_mask(args, params, model, form, batch)
        result = model.infer(batch, batch.valid * mask, [mask_form], log_batch=log_batch, additional=additional)
        grad_variables = [result.full_active_input] if args.active.include_gradient else []
        compute_models, optims = model.get_model_optim([mask_form])
        optim, compute_model = optims[0], compute_models[0]

        loss = apply_regularizers(- result[mask_form].log_probs, args, params, model, batch, results=result[mask_form])
        result.reg_loss = loss
        result.gradients = run_optimizer(optim, compute_model, loss, grad_variables=grad_variables)
        if intermediate_logger is not None:
            # global step index across outer iterations
            step = itr_num * args.active.active_steps + i
            intermediate_logger.log(step, {"rand_mask": result}, intermediate_name = "_full")
    return result